package com.lw.leetcode.binary.c;

/**
 * LeetCode 2040. Kth Smallest Product of Two Sorted Arrays.
 *
 * <p>Finds the k-th smallest (1-based) value among all products {@code nums1[i] * nums2[j]}
 * without materializing them. Products are grouped by sign:
 * <ul>
 *   <li>negative: a negative of one array times a positive of the other;</li>
 *   <li>zero: any pair containing a zero;</li>
 *   <li>positive: positive × positive or negative × negative.</li>
 * </ul>
 * Inside the negative or positive group, the answer is located by binary searching on the
 * product magnitude and counting pairs with a two-pointer sweep.
 *
 * @author liw
 * @version 1.0
 * @date 2023/2/3 16:30
 */
public class KthSmallestProduct {

    public static void main(String[] args) {
        KthSmallestProduct test = new KthSmallestProduct();

        // LeetCode example 1, expected 8
        System.out.println(test.kthSmallestProduct(new int[] {2, 5}, new int[] {3, 4}, 2));

        // LeetCode example 2, expected 0
        System.out.println(test.kthSmallestProduct(new int[] {-4, -2, 0, 3}, new int[] {2, 4}, 6));

        // LeetCode example 3, expected -6
        System.out.println(
                test.kthSmallestProduct(new int[] {-2, -1, 0, 1, 2}, new int[] {-3, -1, 2, 4, 5}, 3));
    }

    /**
     * Returns the k-th smallest product {@code nums1[i] * nums2[j]} over all index pairs.
     *
     * <p>Assumes element magnitudes are small enough that negation does not overflow
     * (LeetCode bounds them by 1e5); products are computed in {@code long}.
     *
     * @param nums1 first array, sorted in non-decreasing order
     * @param nums2 second array, sorted in non-decreasing order
     * @param k     1-based rank of the wanted product
     * @return the k-th smallest product
     * @throws IllegalArgumentException if {@code k} is not in {@code [1, nums1.length * nums2.length]},
     *                                  which includes the case where either array is empty
     */
    public long kthSmallestProduct(int[] nums1, int[] nums2, long k) {
        int length1 = nums1.length;
        int length2 = nums2.length;
        long total = (long) length1 * length2;
        if (k < 1 || k > total) {
            throw new IllegalArgumentException("k must be in [1, " + total + "], got " + k);
        }

        // Sorted input => each array is [negatives | zeros | positives], each part contiguous.
        int neg1 = ltZero(nums1) + 1;
        int pos1Start = gtZero(nums1);
        int zero1 = pos1Start - neg1;
        int pos1 = length1 - pos1Start;

        int neg2 = ltZero(nums2) + 1;
        int pos2Start = gtZero(nums2);
        int zero2 = pos2Start - neg2;
        int pos2 = length2 - pos2Start;

        // Magnitudes of the negatives and the positives of each array, all in ascending order.
        int[] negAbs1 = negatedNegatives(nums1, neg1);
        int[] posVals1 = positives(nums1, pos1);
        int[] negAbs2 = negatedNegatives(nums2, neg2);
        int[] posVals2 = positives(nums2, pos2);

        long negativeCount = (long) neg1 * pos2 + (long) neg2 * pos1;
        if (k <= negativeCount) {
            // The k-th smallest negative product has the (negativeCount - k + 1)-th smallest magnitude.
            return -find(negAbs1, posVals2, negAbs2, posVals1, negativeCount - k + 1);
        }

        // Zeros of nums1 pair with everything; zeros of nums2 pair only with non-zeros of nums1
        // so that zero × zero pairs are not counted twice.
        long zeroCount = (long) zero1 * length2 + (long) zero2 * (length1 - zero1);
        if (k <= negativeCount + zeroCount) {
            return 0;
        }

        return find(posVals1, posVals2, negAbs1, negAbs2, k - negativeCount - zeroCount);
    }

    /**
     * Returns the k-th smallest product among all pairs of (nums1 × nums2) together with all pairs
     * of (nums3 × nums4).
     *
     * <p>Every array must be sorted ascending and contain only positive values, so every product is
     * at least 1. {@code k} is expected to lie in
     * {@code [1, nums1.length * nums2.length + nums3.length * nums4.length]}; the result is not
     * meaningful otherwise.
     *
     * @param nums1 first array of the first pair group
     * @param nums2 second array of the first pair group
     * @param nums3 first array of the second pair group
     * @param nums4 second array of the second pair group
     * @param k     1-based rank of the wanted product
     * @return the k-th smallest product across both groups
     */
    public long find(int[] nums1, int[] nums2, int[] nums3, int[] nums4, long k) {
        long lo = 1L;
        long hi = Math.max(maxProduct(nums1, nums2), maxProduct(nums3, nums4));
        // Smallest value v such that at least k products are <= v.
        while (lo < hi) {
            long mid = lo + ((hi - lo) >> 1);
            long count = countAtMost(nums1, nums2, mid) + countAtMost(nums3, nums4, mid);
            if (count >= k) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }

    /** Largest product of a pair from two ascending positive arrays, or 0 if either is empty. */
    private static long maxProduct(int[] a, int[] b) {
        if (a.length == 0 || b.length == 0) {
            return 0;
        }
        return (long) a[a.length - 1] * b[b.length - 1];
    }

    /**
     * Counts index pairs (i, j) with {@code a[i] * b[j] <= limit}, for ascending positive arrays.
     * As a[i] grows, the largest admissible j can only shrink, so one sweep suffices.
     */
    private static long countAtMost(int[] a, int[] b, long limit) {
        long count = 0;
        int j = b.length - 1;
        for (int i = 0; i < a.length && j >= 0; i++) {
            while (j >= 0 && (long) a[i] * b[j] > limit) {
                j--;
            }
            // b[0..j] all pair with a[i] within the limit.
            count += j + 1;
        }
        return count;
    }

    /** Returns the first {@code count} elements of {@code nums} negated and reversed, i.e. ascending. */
    private static int[] negatedNegatives(int[] nums, int count) {
        int[] out = new int[count];
        for (int i = 0; i < count; i++) {
            out[i] = -nums[count - 1 - i];
        }
        return out;
    }

    /** Returns a copy of the last {@code count} elements of {@code nums}. */
    private static int[] positives(int[] nums, int count) {
        int[] out = new int[count];
        System.arraycopy(nums, nums.length - count, out, 0, count);
        return out;
    }

    /** Returns the index of the first element {@code > 0}, or {@code nums.length} if there is none. */
    private static int gtZero(int[] nums) {
        int st = 0;
        int end = nums.length;
        // Lower-bound search over the half-open range [st, end); safe for empty arrays.
        while (st < end) {
            int m = st + ((end - st) >> 1);
            if (nums[m] > 0) {
                end = m;
            } else {
                st = m + 1;
            }
        }
        return st;
    }

    /** Returns the index of the last element {@code < 0}, or -1 if there is none. */
    private static int ltZero(int[] nums) {
        int st = 0;
        int end = nums.length;
        // Find the first index holding a value >= 0; the last negative sits just before it.
        while (st < end) {
            int m = st + ((end - st) >> 1);
            if (nums[m] < 0) {
                st = m + 1;
            } else {
                end = m;
            }
        }
        return st - 1;
    }

}
